    def get_next(data, label, batch_size=None):
        """Select the samples used for the next weight update.

        The selection strategy is read from ``update_policy``, a name resolved
        outside this function (presumably the enclosing scope -- not visible here):

        * ``'batch'``     -- return ``data`` and ``label`` unchanged.
        * ``'single'``    -- return one uniformly random sample; data and label are
          each given a leading axis of length 1 via ``np.expand_dims``.
        * ``'minibatch'`` -- return ``batch_size`` distinct random samples, drawn
          without replacement.

        Args:
            data: Indexable training inputs. The minibatch policy indexes it with a
                list of positions, so it should support NumPy-style fancy indexing
                (e.g. an ``np.ndarray``). Any number of dimensions is accepted.
            label: Labels aligned with ``data`` along the first axis.
            batch_size: Number of samples to draw; required for ``'minibatch'``,
                ignored by the other policies.

        Returns:
            A ``(data, label)`` tuple holding the selected samples.

        Raises:
            ValueError: if ``update_policy`` is not a recognised value, if ``data``
                and ``label`` differ in length, or if ``batch_size`` is missing for
                the minibatch policy. ``random.randint`` / ``random.sample`` also
                raise ``ValueError`` when ``data`` is empty or ``batch_size``
                exceeds ``len(data)``.
        """
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if update_policy not in ('batch', 'minibatch', 'single'):
            raise ValueError(
                'Weight update policy must be one of the following values: `batch`, `minibatch`, `single`!')
        if len(data) != len(label):
            raise ValueError(
                'You must provide the same number of training data points and labels in the dataset!')

        if update_policy == 'batch':
            return data, label
        if update_policy == 'single':
            index = random.randint(0, len(data) - 1)
            return np.expand_dims(data[index], 0), np.expand_dims(label[index], 0)

        # Remaining policy: 'minibatch'.
        if batch_size is None:
            raise ValueError('You must provide a batch size to use minibatched weight update policy!')
        used_in_batch = random.sample(range(len(data)), batch_size)
        # Index only the first axis so inputs of any rank (including 1-D) work,
        # matching the 'single' branch.
        return data[used_in_batch], label[used_in_batch]
